{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/oumi-ai/oumi/blob/main/configs/projects/halloumi/halloumi_eval_notebook.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating LLMs as Hallucination Classifiers\n",
    "\n",
    "This notebook demonstrates how we evaluated numerous Large Language Models (LLMs) as hallucination classifiers. \n",
    "To see our full evaluation results and learn more about our [groundedness benchmark](https://huggingface.co/datasets/oumi-ai/oumi-groundedness-benchmark), \n",
    "please see our [technical overview](https://oumi.ai/blog/posts/introducing-halloumi)!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Oumi\n",
    "\n",
    "First, let's install Oumi. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### API Keys\n",
    "\n",
    "To be able to make remote API calls to Open AI, Gemini, and Anthropic you will need the following keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"\"  # Set your OpenAI API key here.\n",
    "os.environ[\"GEMINI_API_KEY\"] = \"\"  # Set your Gemini API key here.\n",
    "os.environ[\"ANTHROPIC_API_KEY\"] = \"\"  # Set your  Anthropic API key here.\n",
    "\n",
    "# Set your GCP project id and region, to be able to query LLaMA 405B in Vertex.\n",
    "REGION = \"\"  # Set your GCP region here.\n",
    "PROJECT_ID = \"\"  # Set your GCP project id here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Limit the Number of Examples\n",
    "\n",
    "This notebook limits the total number of examples to 10, for testing purposes. To use the full dataset, set the number of examples to `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EXAMPLES = 10  # Replace with `None` for full dataset evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Dataset\n",
    "\n",
    "You can load the test split of [Oumi's Groundedness Benchmark](https://huggingface.co/datasets/oumi-ai/oumi-groundedness-benchmark) as follows. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.datasets.sft.prompt_response import PromptResponseDataset\n",
    "\n",
    "groundedness_dataset = PromptResponseDataset(\n",
    "    hf_dataset_path=\"oumi-ai/oumi-groundedness-benchmark\",\n",
    "    response_column=\"\",\n",
    "    split=\"test\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code snippet extracts the labels of the dataset and converts them from `str`s (`SUPPORTED` or `UNSUPPORTED`) to `int`s (`0` or `1`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "groundedness_labels_str = groundedness_dataset.data.label.tolist()\n",
    "groundedness_labels = [\n",
    "    0 if label == \"SUPPORTED\" else 1 for label in groundedness_labels_str\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction Extraction\n",
    "\n",
    "The following function extracts the prediction from a free-text response generated by a LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _extract_prediction(response):\n",
    "    \"\"\"Returns: 0 if supported, 1 if unsupported.\"\"\"\n",
    "    is_unsupported = \"<|unsupported|>\" in response or \"<unsupported>\" in response\n",
    "    return 1 if is_unsupported else 0"
   ]
  },
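  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below runs `_extract_prediction` (defined in the previous cell) on a few responses. The sample strings, including the `<|supported|>` tag, are made up for illustration and are not taken from the benchmark or from any model's actual output. The first response should map to `0` and the other two to `1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical responses, made up for illustration; real model outputs will differ.\n",
    "sample_responses = [\n",
    "    \"The claim follows from the context. <|supported|>\",\n",
    "    \"The context never mentions this. <|unsupported|>\",\n",
    "    \"<unsupported> No evidence in the document.\",\n",
    "]\n",
    "\n",
    "# Print the extracted prediction next to each sample response.\n",
    "for sample_response in sample_responses:\n",
    "    print(_extract_prediction(sample_response), sample_response)"
   ]
  },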
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom Evaluation Function\n",
    "\n",
    "The following code snippet demonstrates how we register a custom evaluation function that produces the F1 score and the Balanced Accuracy score, given a set of prompts and their corresponding labels. The prompts have been provided in Oumi format (a list of `Conversation`s), while the labels are a list of `int`s (`0` or `1`). An inference engine is also instantiated by Oumi, based on the provided configuration, discussed below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.evaluation.metrics import bacc_score, f1_score\n",
    "from oumi.core.registry import register_evaluation_function\n",
    "\n",
    "\n",
    "@register_evaluation_function(\"hallucination_classification\")\n",
    "def hallucination_classification(inference_engine, conversations, labels):\n",
    "    \"\"\"Hallucination classification evaluation function.\"\"\"\n",
    "    # Run inference to append the model response into each conversation.\n",
    "    conversations = inference_engine.infer(conversations)\n",
    "\n",
    "    y_true, y_pred = [], []\n",
    "    for conversation, label in zip(conversations, labels):\n",
    "        # Extract the model response from the conversation.\n",
    "        response = conversation.last_message()\n",
    "\n",
    "        # Extract the prediction from the response.\n",
    "        prediction = _extract_prediction(response.content)\n",
    "\n",
    "        # Record the predictions together with their labels.\n",
    "        y_pred.append(prediction)\n",
    "        y_true.append(label)\n",
    "\n",
    "    # Compute relevant metrics and return them.\n",
    "    return {\n",
    "        \"f1\": f1_score(y_true, y_pred, average=\"macro\"),\n",
    "        \"bacc\": bacc_score(y_true, y_pred),\n",
    "    }"
   ]
  },
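  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two metrics concrete, here is a minimal sketch on toy labels. It uses scikit-learn's `f1_score` and `balanced_accuracy_score` as stand-ins. This is an assumption: scikit-learn must be installed, and these functions are not necessarily what the Oumi helpers above compute internally. The imports are aliased so that they do not shadow the names used by the evaluation function.\n",
    "\n",
    "With these toy labels, the per-class recalls are 0.5 and 1.0, so the balanced accuracy is 0.75. The per-class F1 scores are 2/3 and 0.8, so the macro F1 is about 0.733."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: scikit-learn is available; aliases avoid shadowing the Oumi metrics.\n",
    "from sklearn.metrics import balanced_accuracy_score as sk_bacc_score\n",
    "from sklearn.metrics import f1_score as sk_f1_score\n",
    "\n",
    "# Toy labels, made up for illustration: 0 = supported, 1 = unsupported.\n",
    "toy_true = [0, 0, 1, 1]\n",
    "toy_pred = [0, 1, 1, 1]\n",
    "\n",
    "# Macro F1 averages the per-class F1 scores with equal weight.\n",
    "toy_f1 = sk_f1_score(toy_true, toy_pred, average=\"macro\")\n",
    "# Balanced accuracy is the mean of the per-class recalls.\n",
    "toy_bacc = sk_bacc_score(toy_true, toy_pred)\n",
    "\n",
    "print(f\"macro F1 = {toy_f1:.3f}, balanced accuracy = {toy_bacc:.3f}\")"
   ]
  },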
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation Configurations\n",
    "\n",
    "The following configurations represent the models we have evaluated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = [\n",
    "    \"halloumi 8b\",\n",
    "    # Uncomment any models you wish to evaluate - you can evaluate multiple at once.\n",
    "    # \"gpt_4o\",\n",
    "    # \"o1_preview\",\n",
    "    # \"gemini_pro\",\n",
    "    # \"llama_405b\",\n",
    "    # \"sonnet\",\n",
    "]\n",
    "\n",
    "configs = {\n",
    "    \"gpt_4o\": \"\"\"\n",
    "      model:\n",
    "        model_name: \"gpt-4o\"\n",
    "\n",
    "      inference_engine: OPENAI\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_key_env_varname: \"OPENAI_API_KEY\"\n",
    "        max_retries: 3\n",
    "        num_workers: 100\n",
    "        politeness_policy: 60\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 0.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "    \"o1_preview\": \"\"\"\n",
    "      model:\n",
    "        model_name: \"o1-preview\"\n",
    "\n",
    "      inference_engine: OPENAI\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_key_env_varname: \"OPENAI_API_KEY\"\n",
    "        max_retries: 3\n",
    "        num_workers: 100\n",
    "        politeness_policy: 60\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 1.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "    \"gemini_pro\": \"\"\"\n",
    "      model:\n",
    "        model_name: \"gemini-1.5-pro\"\n",
    "\n",
    "      inference_engine: GOOGLE_GEMINI\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_key_env_varname: \"GEMINI_API_KEY\"\n",
    "        max_retries: 3\n",
    "        num_workers: 2\n",
    "        politeness_policy: 60\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 0.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "    \"llama_405b\": f\"\"\"\n",
    "      model:\n",
    "        model_name: \"meta/llama-3.1-405b-instruct-maas\"\n",
    "\n",
    "      inference_engine: GOOGLE_VERTEX\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_url: \"https://{REGION}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi/chat/completions\"\n",
    "        max_retries: 3\n",
    "        num_workers: 10\n",
    "        politeness_policy: 60\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 0.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "    \"sonnet\": \"\"\"\n",
    "      model:\n",
    "        model_name: \"claude-3-5-sonnet-latest\"\n",
    "\n",
    "      inference_engine: ANTHROPIC\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_key_env_varname: \"ANTHROPIC_API_KEY\"\n",
    "        max_retries: 3\n",
    "        num_workers: 5\n",
    "        politeness_policy: 65\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 0.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "    \"halloumi 8b\": \"\"\"\n",
    "      model:\n",
    "        model_name: \"halloumi\"\n",
    "\n",
    "      inference_engine: REMOTE\n",
    "\n",
    "      inference_remote_params:\n",
    "        api_url: \"https://api.oumi.ai/chat/completions\"\n",
    "        max_retries: 3\n",
    "        connection_timeout: 300\n",
    "\n",
    "      generation:\n",
    "        max_new_tokens: 8192\n",
    "        temperature: 0.0\n",
    "\n",
    "      tasks:\n",
    "        - evaluation_backend: custom\n",
    "          task_name: hallucination_classification\n",
    "      \"\"\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running the Evaluation\n",
    "\n",
    "Putting everything together: Below, we demonstrate how to run the evaluation for each model sequentially. For each model, we first create the evaluation configuration, convert the dataset prompts into Oumi format (i.e., `conversation`s), and then initiate the evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.configs import EvaluationConfig\n",
    "from oumi.core.evaluation import Evaluator\n",
    "\n",
    "results = {model_name: {} for model_name in model_names}\n",
    "\n",
    "for model_name in model_names:\n",
    "    # Create the evaluation config from the YAML string.\n",
    "    config_yaml: str = configs[model_name]\n",
    "    config = EvaluationConfig.from_str(config_yaml)\n",
    "\n",
    "    # Each model has a different (but very similar) prompt, thus we set the prompt\n",
    "    # column to relevant model name, before running the evaluation.\n",
    "    groundedness_dataset.prompt_column = f\"{model_name} prompt\"\n",
    "\n",
    "    # Convert the prompts into Oumi conversations.\n",
    "    conversations = groundedness_dataset.conversations()\n",
    "    if NUM_EXAMPLES:\n",
    "        conversations = conversations[:NUM_EXAMPLES]\n",
    "\n",
    "    # Run the evaluation.\n",
    "    evaluator = Evaluator()\n",
    "    evaluator_out = evaluator.evaluate(\n",
    "        config=config,\n",
    "        conversations=conversations,\n",
    "        labels=groundedness_labels,\n",
    "    )\n",
    "\n",
    "    # Record the results.\n",
    "    results[model_name] = evaluator_out[0].get_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect the Results\n",
    "\n",
    "Next, we inspect the performance metrics for each evaluated model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name in model_names:\n",
    "    print(\n",
    "        f\"{model_name}: BACC = {results[model_name]['bacc'].value:0.1%}\"\n",
    "        f\" ± {results[model_name]['bacc'].ci:0.1%}\"\n",
    "    )\n",
    "\n",
    "for model_name in model_names:\n",
    "    print(\n",
    "        f\"{model_name}: F1 = {results[model_name]['f1'].value:0.1%}\"\n",
    "        f\" ± {results[model_name]['f1'].ci:0.1%}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To wrap things up, we discussed how to use Oumi's custom evaluation framework to evaluate multiple open and closed source LLMs as hallucination classifiers, using the [newly-introduced groundedness benchmark](https://huggingface.co/datasets/oumi-ai/oumi-groundedness-benchmark).\n",
    "\n",
    "If you'd like to try inference on HallOumi, please also check out our [inference notebook](https://github.com/oumi-ai/oumi/blob/main/configs/projects/halloumi/halloumi_inference_notebook.ipynb)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
